
Conversation

@vasqu (Contributor) commented Oct 24, 2025

Non-vmap creation of masks. This works with all our base masks; we only fall back to vmap when using patterns we cannot guarantee (i.e. additional and/or-combined mask functions). A minimal sketch of the broadcasting idea follows the notes below.

Note:

  • Non-vmap works with every mask that is purely index-based
  • Merged the old/new SDPA mask creation under one function --> easier maintenance imo
  • ExecuTorch does not need an additional masking function anymore
  • Lifts some restrictions on older torch versions, e.g. chunked attention with padding, packed attention masks, etc.
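
For illustration, a minimal sketch of the broadcasting idea (hypothetical and simplified; the mask_function(batch_idx, head_idx, q_idx, kv_idx) signature follows the usual convention, but this is not the exact implementation):

import torch

def causal_mask_function(batch_idx, head_idx, q_idx, kv_idx):
    # Purely index-based: a comparison between broadcastable index tensors
    return kv_idx <= q_idx

# Instead of vmapping over index vectors, pass broadcastable index grids directly
q_idx = torch.arange(4).view(-1, 1)   # shape (q_len, 1)
kv_idx = torch.arange(6).view(1, -1)  # shape (1, kv_len)
mask = causal_mask_function(None, None, q_idx, kv_idx)  # bool mask of shape (q_len, kv_len)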

Fixes #41639 (MarianMTModel performance regression due to Bidirectional masks)

cc @jiqing-feng @IlyasMoutawwakil

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@vasqu vasqu changed the title [WIP][Masking] Non-vmap default for attention masks [Attn Masks] Non-vmap default for attention masks Oct 29, 2025
@vasqu vasqu marked this pull request as ready for review October 29, 2025 11:06
return cache


def sdpa_mask_without_vmap(
@vasqu (Contributor, Author):

No longer needed, as vmap was the reason we needed this workaround in the first place.

NOTE: It is important to keep an index-based version for non-vmap expansion.
"""
- return q_idx.new_ones((), dtype=torch.bool)
+ return q_idx >= 0
@vasqu (Contributor, Author):

As noted above, for non-vmap we need this as an index-based version.
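
A quick illustration of the difference (hypothetical snippet): the scalar form carries no index shape, while the comparison follows the index grid, so it survives the broadcasting expansion:

import torch

q_idx = torch.arange(4).view(-1, 1)  # index grid of shape (q_len, 1)

q_idx.new_ones((), dtype=torch.bool).shape  # torch.Size([]): a lone scalar True
(q_idx >= 0).shape                          # torch.Size([4, 1]): tracks the index shape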

causal_mask |= torch.all(~causal_mask, dim=-1, keepdim=True)
return causal_mask

attention_mask = attention_mask | torch.all(~attention_mask, dim=-1, keepdim=True)
@vasqu (Contributor, Author):

I encountered issues with the in-place version, where we'd need a clone (e.g. when using SWA). This is safer.
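
A minimal illustration of the hazard (hypothetical, assuming the mask can be an expanded broadcast view, as with sliding-window attention):

import torch

base = torch.zeros(1, 4, dtype=torch.bool)
mask = base.expand(3, 4)  # broadcast view, no copy: all rows alias the same storage

# An in-place OR (mask |= ...) on this view raises a RuntimeError because
# multiple elements of the written-to tensor share one memory location,
# so we'd have to clone first. The out-of-place version allocates a fresh tensor:
mask = mask | torch.all(~mask, dim=-1, keepdim=True)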

@ArthurZucker (Collaborator) left a comment:

Can we add a test to the default tests to check that there is no graph break on this?

@jiqing-feng (Contributor) commented:

Hi @vasqu. Is anything blocking the merge?

@Cyrilvallez (Member) left a comment:

Very very nice, this solves quite a lot of different issues at the same time! I'm very happy to avoid special handling for export! Very clever use of broadcasting from the optimum team; I did not know we could simply do such things! Thanks a lot for upstreaming directly to us.

Do you mind expanding a bit more, for posterity, on what the limitations of the broadcasting approach are? Is it only index-based operations, as you mention in the comments, or are there more subtle things?

@IlyasMoutawwakil (Member) commented Nov 10, 2025

> Do you mind expanding a bit more, for posterity, on what the limitations of the broadcasting approach are? Is it only index-based operations, as you mention in the comments, or are there more subtle things?

Nothing I'm aware of; the only condition is to write the mask_function as a comparison between the indexes (and constants).
One example of this is bidirectional_mask_function, which was written as q_idx.new_ones((), dtype=torch.bool), i.e. it simply returns a scalar 1 (true); @vasqu rewrote it as q_idx >= 0 (always true). I think any mask function can be written as f(b, h, q, kv), but I can't prove it 😂

@vasqu (Contributor, Author) commented Nov 10, 2025

Merging this then! Let's see what crazy masks come up in the future; for now the "mask hypothesis" holds 😆

@vasqu vasqu merged commit 03538a8 into huggingface:main Nov 10, 2025
23 checks passed
@vasqu vasqu deleted the non-vmap-masks branch November 10, 2025 15:04
@vasqu (Contributor, Author) commented Nov 10, 2025

@jiqing-feng it was mainly blocked by me being out last week ;)
